# -*- coding: utf-8 -*-
"""
Created on Wed Nov  8 15:27:55 2017

@author: xuanlei
"""

import tensorflow as tf
import numpy as np
#==============================================================================
# Batch Normalization
#==============================================================================
def batch_norm_layer(x, train_phase, scope_bn):
    """Batch normalization with moving averages for inference.

    train_phase: boolean tensor. True normalizes with the current batch
    statistics and updates the moving averages; False normalizes with the
    stored moving averages instead.
    """
    with tf.variable_scope(scope_bn):
        beta = tf.Variable(tf.constant(0.0, shape=[x.shape[-1]]), name='beta', trainable=True)
        gamma = tf.Variable(tf.constant(1.0, shape=[x.shape[-1]]), name='gamma', trainable=True)
        # Normalize over every axis except the last (feature) axis.
        axes = list(range(len(x.shape) - 1))
        batch_mean, batch_var = tf.nn.moments(x, axes, name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # Update the moving averages, then return the batch statistics.
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(train_phase, mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
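
# Illustrative use (hypothetical names), showing how the train_phase flag
# switches between batch statistics and the moving averages:
#
#   h = tf.placeholder(tf.float32, [None, 64])
#   is_training = tf.placeholder(tf.bool)
#   h_bn = batch_norm_layer(h, train_phase=is_training, scope_bn='bn_demo')
#   # training:  sess.run(..., feed_dict={is_training: True})   # updates EMAs
#   # inference: sess.run(..., feed_dict={is_training: False})  # reads EMAs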

def xavier_init(fan_in, fan_out, constant=1):
    """Xavier (Glorot) uniform initialization.

    fan_in: number of input nodes
    fan_out: number of output nodes
    """
    # 6.0 keeps this a float division under Python 2 as well.
    low = -constant * np.sqrt(6.0 / (fan_in + fan_out))
    high = constant * np.sqrt(6.0 / (fan_in + fan_out))
    return tf.random_uniform((fan_in, fan_out), minval=low, maxval=high, dtype=tf.float32)
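
# Example (hypothetical sizes): a 100x50 weight matrix drawn uniformly from
# [-sqrt(6/150), +sqrt(6/150)], i.e. roughly [-0.2, 0.2]:
#
#   W = tf.Variable(xavier_init(100, 50), name='W_demo')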

#==============================================================================
# Network structure
#==============================================================================
class LSTMRNN():
    """Despite the name, a three-hidden-layer feedforward network with batch
    normalization after every hidden layer; no recurrent cells are used.

    transfer: activation applied to each hidden layer (tanh by default).
    LR: learning rate for the Adam optimizer.
    """

    def __init__(self, input_size, output_size, h1_size, h2_size, h3_size, LR,
                 batch_size, transfer=tf.nn.tanh):
        self.input_size = input_size
        self.output_size = output_size
        self.h1_size = h1_size
        self.h2_size = h2_size
        self.h3_size = h3_size
        self.batch_size = batch_size
        self.LR = LR
        self.transfer = transfer
        
        with tf.name_scope('inputs'):
            self.xs = tf.placeholder(tf.float32, [None, input_size], name='xs')
            self.ys = tf.placeholder(tf.float32, [None, output_size], name='ys')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')  # declared but never used in the graph
            self.train_phase = tf.placeholder(tf.bool, name='train_phase')

        with tf.name_scope('hidden_1'):
            self.add_h1_layer()
            
        with tf.name_scope('hidden_2'):
            self.add_h2_layer()

        with tf.name_scope('hidden_3'):
            self.add_h3_layer()

        with tf.name_scope('out_hidden'):
            self.add_output_layer()
    
        with tf.name_scope('cost'):
            self.compute_cost()

        with tf.name_scope('train'):
            self.train_op = tf.train.AdamOptimizer(learning_rate=self.LR).minimize(self.cost)

    def add_h1_layer(self):
        with tf.name_scope('h1_layer'):
            h1_x = tf.reshape(self.xs, [-1, self.input_size], name='x_input')
            Ws_h1 = tf.Variable(xavier_init(self.input_size, self.h1_size), name="W1")
            bs_h1 = tf.Variable(tf.zeros([self.h1_size]) + 0.01)
            # Affine transform + activation, then batch norm on the activations.
            non_bn_h1 = self.transfer(tf.matmul(h1_x, Ws_h1) + bs_h1)
            self.h1_y = batch_norm_layer(non_bn_h1, train_phase=self.train_phase, scope_bn='bn_h1')

    def add_h2_layer(self):
        with tf.name_scope('h2_layer'):
            h2_x = tf.reshape(self.h1_y, [-1, self.h1_size])
            Ws_h2 = tf.Variable(xavier_init(self.h1_size, self.h2_size), name="W2")
            bs_h2 = tf.Variable(tf.zeros([self.h2_size]) + 0.01)
            non_bn_h2 = self.transfer(tf.matmul(h2_x, Ws_h2) + bs_h2)
            self.h2_y = batch_norm_layer(non_bn_h2, train_phase=self.train_phase, scope_bn='bn_h2')

    def add_h3_layer(self):
        with tf.name_scope('h3_layer'):
            h3_x = tf.reshape(self.h2_y, [-1, self.h2_size])
            Ws_h3 = tf.Variable(xavier_init(self.h2_size, self.h3_size), name="W3")
            bs_h3 = tf.Variable(tf.zeros([self.h3_size]) + 0.01)
            non_bn_h3 = self.transfer(tf.matmul(h3_x, Ws_h3) + bs_h3)
            self.h3_y = batch_norm_layer(non_bn_h3, train_phase=self.train_phase, scope_bn='bn_h3')

    def add_output_layer(self):
        with tf.name_scope('output_layer'):
            l_out_x = tf.reshape(self.h3_y, [-1, self.h3_size], name='y_input')
            Ws_out = tf.Variable(xavier_init(self.h3_size, self.output_size), name="W4")
            bs_out = tf.Variable(tf.zeros([self.output_size]))
            # Sigmoid keeps the predictions in (0, 1).
            self.pred = tf.nn.sigmoid(tf.matmul(l_out_x, Ws_out) + bs_out)
            tf.summary.histogram('w', Ws_out)
            tf.summary.histogram('b', bs_out)
            tf.summary.histogram('out', self.pred)

#===============================================================================
# Cost: root-mean-square error
#===============================================================================
    def compute_cost(self):
        with tf.name_scope('loss'):
            # RMSE between predictions and targets, built from graph ops so
            # it stays a symbolic tensor that the optimizer can minimize.
            self.cost = tf.sqrt(tf.reduce_mean(tf.square(tf.subtract(self.pred, self.ys))))
            tf.summary.scalar('result_cost', self.cost)
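
#==============================================================================
# Usage sketch: a minimal, hypothetical smoke test on random data. The layer
# sizes, learning rate, and step count below are illustrative assumptions,
# not values from the original project.
#==============================================================================
if __name__ == '__main__':
    model = LSTMRNN(input_size=10, output_size=1, h1_size=32, h2_size=16,
                    h3_size=8, LR=0.001, batch_size=64)
    x_demo = np.random.rand(64, 10).astype(np.float32)             # fake features
    y_demo = np.random.randint(0, 2, (64, 1)).astype(np.float32)   # fake targets
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for step in range(10):
            # train_phase=True: normalize with batch statistics and update the
            # moving averages via the control dependency in batch_norm_layer.
            _, cost = sess.run([model.train_op, model.cost],
                               feed_dict={model.xs: x_demo, model.ys: y_demo,
                                          model.train_phase: True})
            print('step %d, cost %.4f' % (step, cost))
        # train_phase=False: normalize with the stored moving averages.
        preds = sess.run(model.pred, feed_dict={model.xs: x_demo,
                                                model.train_phase: False})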